{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fb7d0ced",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from selfcheckgpt.modeling_mqag import MQAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cf7eaa7f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x2ab879064250>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# seed PyTorch's random number generator; question generation samples from the model\n",
    "torch.manual_seed(28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d95db1dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "03a5b722",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MQAG (race) initialized to cuda\n"
     ]
    }
   ],
   "source": [
    "mqag_model = MQAG(\n",
    "    g1_model_type='race',  # question generator: 'race' (more abstractive) or 'squad' (more extractive)\n",
    "    device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c96587f",
   "metadata": {},
   "source": [
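    "# MQAG Score: `score`\n",
    "- `score` uses MQAG to assess how consistent a summary (`candidate`) is with its source document (`reference`), as described in https://arxiv.org/abs/2301.12307.\n",
    "- It returns four distances between the answer distributions obtained from the two texts (`kl_div`, `counting`, `hellinger`, `total_variation`); lower values indicate a more consistent summary."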
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5b71a50a",
   "metadata": {},
   "outputs": [],
   "source": [
    "document = r\"\"\"\n",
    "World number one Novak Djokovic says he is hoping for a \"positive decision\" to allow him\n",
    "to play at Indian Wells and the Miami Open next month. The United States has extended\n",
    "its requirement for international visitors to be vaccinated against Covid-19. Proof of vaccination\n",
    "will be required to enter the country until at least 10 April, but the Serbian has previously\n",
    "said he is unvaccinated. The 35-year-old has applied for special permission to enter the country.\n",
    "Indian Wells and the Miami Open - two of the most prestigious tournaments on the tennis calendar\n",
    "outside the Grand Slams - start on 6 and 20 March respectively. Djokovic says he will return to\n",
    "the ATP tour in Dubai next week after claiming a record-extending 10th Australian Open title\n",
    "and a record-equalling 22nd Grand Slam men's title last month.\"\"\".replace(\"\\n\", \" \").strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "00e8629c",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary = \"Djokvic might be allowed to play in the US next month. Djokovic will play in Qatar next week after winning his 5th Grand slam.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a36a88f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initialized Generation\n",
      "Initialized Answering\n",
      "Q1: What is the best title of this passage?\n",
      "(1) [P(.|cand)=16.74%]\t[P(.|ref)=34.21%]\tDjokovic playing in Qatar\n",
      "(2) [P(.|cand)=16.42%]\t[P(.|ref)=28.77%]\tDjokvic playing in Qatar\n",
      "(3) [P(.|cand)=50.09%]\t[P(.|ref)=2.81%]\tDjokovic's 5th Grand slam\n",
      "(4) [P(.|cand)=16.74%]\t[P(.|ref)=34.21%]\tDjokovic playing in Qatar\n",
      "-------------------------------------------------------------------------------\n",
      "Q2: Djokovic won his 5th Grand Slam _.\n",
      "(1) [P(.|cand)=39.97%]\t[P(.|ref)=36.00%]\tin Qatar\n",
      "(2) [P(.|cand)=7.03%]\t[P(.|ref)=19.41%]\tin China\n",
      "(3) [P(.|cand)=20.08%]\t[P(.|ref)=18.32%]\tin Germany\n",
      "(4) [P(.|cand)=32.91%]\t[P(.|ref)=26.26%]\tin England\n",
      "-------------------------------------------------------------------------------\n",
      "Q3: Djokovic might play in the US next month because _.\n",
      "(1) [P(.|cand)=5.21%]\t[P(.|ref)=93.53%]\the is allowed to play there\n",
      "(2) [P(.|cand)=47.24%]\t[P(.|ref)=1.64%]\the won his 5th Grand slam\n",
      "(3) [P(.|cand)=0.31%]\t[P(.|ref)=3.19%]\the won’t play in Qatar\n",
      "(4) [P(.|cand)=47.24%]\t[P(.|ref)=1.64%]\the won his 5th Grand slam\n",
      "-------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
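    "# verbose=True prints each question generated from the candidate, with answer probabilities given the candidate and given the reference\n",
    "score = mqag_model.score(candidate=summary, reference=document, num_questions=3, verbose=True)"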
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "822ceb65",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KL-div    = 1.3974249426743093\n",
      "Counting  = 0.6666666666666666\n",
      "Hellinger = 0.37552807852247866\n",
      "Total Var = 0.4932687332232793\n"
     ]
    }
   ],
   "source": [
    "print(\"KL-div    =\", score['kl_div'])\n",
    "print(\"Counting  =\", score['counting'])\n",
    "print(\"Hellinger =\", score['hellinger'])\n",
    "print(\"Total Var =\", score['total_variation'])"
   ]
  },
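  {
   "cell_type": "markdown",
   "id": "4e1b7c2a",
   "metadata": {},
   "source": [
    "- The scores above can be reproduced approximately from the probabilities printed by `verbose=True`. The cell below is a sketch, not the library's code: the formulas are assumptions inferred by matching the printed scores of this run, each averaged over the three questions.\n",
    "- Under that assumption, `kl_div` is the sum of `p_cand * log(p_cand / p_ref)`, `counting` is the fraction of questions whose most probable option differs, `hellinger` is the sum of squared differences between square-rooted probabilities divided by sqrt(2), and `total_variation` is the largest absolute probability difference.\n",
    "- The inputs are rounded to four decimals, so the results match the printed scores only up to rounding error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d3f5a68",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Probabilities copied (rounded) from the verbose output above; one row per question.\n",
    "p_cand = np.array([\n",
    "    [0.1674, 0.1642, 0.5009, 0.1674],\n",
    "    [0.3997, 0.0703, 0.2008, 0.3291],\n",
    "    [0.0521, 0.4724, 0.0031, 0.4724],\n",
    "])\n",
    "p_ref = np.array([\n",
    "    [0.3421, 0.2877, 0.0281, 0.3421],\n",
    "    [0.3600, 0.1941, 0.1832, 0.2626],\n",
    "    [0.9353, 0.0164, 0.0319, 0.0164],\n",
    "])\n",
    "\n",
    "# Assumed per-question formulas, inferred from the printed scores (not taken from the library source).\n",
    "kl = np.sum(p_cand * np.log(p_cand / p_ref), axis=1)\n",
    "disagree = p_cand.argmax(axis=1) != p_ref.argmax(axis=1)\n",
    "hellinger = np.sum((np.sqrt(p_cand) - np.sqrt(p_ref)) ** 2, axis=1) / np.sqrt(2)\n",
    "total_var = np.max(np.abs(p_cand - p_ref), axis=1)\n",
    "\n",
    "# Average each distance over the questions.\n",
    "print('KL-div    =', kl.mean())\n",
    "print('Counting  =', disagree.mean())\n",
    "print('Hellinger =', hellinger.mean())\n",
    "print('Total Var =', total_var.mean())"
   ]
  },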
  {
   "cell_type": "markdown",
   "id": "63b897c4",
   "metadata": {},
   "source": [
    "# Multiple-choice Question Generation: `generate`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1faff359",
   "metadata": {},
   "outputs": [],
   "source": [
    "context = r\"\"\"\n",
    "World number one Novak Djokovic says he is hoping for a \"positive decision\" to allow him\n",
    "to play at Indian Wells and the Miami Open next month. The United States has extended\n",
    "its requirement for international visitors to be vaccinated against Covid-19. Proof of vaccination\n",
    "will be required to enter the country until at least 10 April, but the Serbian has previously\n",
    "said he is unvaccinated. The 35-year-old has applied for special permission to enter the country.\n",
    "Indian Wells and the Miami Open - two of the most prestigious tournaments on the tennis calendar\n",
    "outside the Grand Slams - start on 6 and 20 March respectively. Djokovic says he will return to\n",
    "the ATP tour in Dubai next week after claiming a record-extending 10th Australian Open title\n",
    "and a record-equalling 22nd Grand Slam men's title last month.\"\"\".replace(\"\\n\", \" \").strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "99bdef7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------------------------------\n",
      "Q0: How many Grand Slam men's titles has Djokovic won?\n",
      "A: 22\n",
      "B: 10\n",
      "C: 20\n",
      "D: 24\n",
      "------------------------------------\n",
      "Q1: Why is Djokovic hoping for a \"positive decision\"?\n",
      "A: To be allowed to play at Indian Wellsand the Miami Open.\n",
      "B: To play at the Miami Open.\n",
      "C: To apply to stay at the US until at least 10 April.\n",
      "D: To have his proof of vaccination.\n",
      "------------------------------------\n",
      "Q2: What does the author want to tell readers?\n",
      "A: Djokovic won't leave the ATP tour in Dubai.\n",
      "B: Djokovic will play at Indian Wells and the Miami Open next month.\n",
      "C: Djokovic won't enter the country until at least 10 April.\n",
      "D: The United States doesn't need foreign visitors to be vaccinated again.\n"
     ]
    }
   ],
   "source": [
    "questions = mqag_model.generate(context=context, do_sample=True, num_questions=3)\n",
    "for i, question_item in enumerate(questions):\n",
    "    print(\"------------------------------------\")\n",
    "    print(f\"Q{i}: {question_item['question']}\")\n",
    "    # label the four options A-D\n",
    "    for label, option in zip(\"ABCD\", question_item['options']):\n",
    "        print(f\"{label}: {option}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79225f79",
   "metadata": {},
   "source": [
    "# Multiple-choice Question Answering: `answer`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "03522d01",
   "metadata": {},
   "outputs": [],
   "source": [
    "context = r\"\"\"Chelsea's mini-revival continued with a third victory in a row as they consigned struggling Leicester City to a fifth consecutive defeat.\n",
    "Buoyed by their Champions League win over Borussia Dortmund, Chelsea started brightly and Ben Chilwell volleyed in from a tight angle against his old club.\n",
    "Chelsea's Joao Felix and Leicester's Kiernan Dewsbury-Hall hit the woodwork in the space of two minutes, then Felix had a goal ruled out by the video assistant referee for offside.\n",
    "Patson Daka rifled home an excellent equaliser after Ricardo Pereira won the ball off the dawdling Felix outside the box.\n",
    "But Kai Havertz pounced six minutes into first-half injury time with an excellent dinked finish from Enzo Fernandez's clever aerial ball.\n",
    "Mykhailo Mudryk thought he had his first goal for the Blues after the break but his effort was disallowed for offside.\n",
    "Mateo Kovacic sealed the win as he volleyed in from Mudryk's header.\n",
    "The sliding Foxes, who ended with 10 men following Wout Faes' late dismissal for a second booking, now just sit one point outside the relegation zone.\n",
    "\"\"\".replace('\\n', ' ')\n",
    "\n",
    "question = \"Who had a goal ruled out for offside?\"\n",
    "options  = ['Ricardo Pereira', 'Ben Chilwell', 'Joao Felix', 'The Foxes']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e0936b0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.00145158 0.00460854 0.99049687 0.00344299]\n"
     ]
    }
   ],
   "source": [
    "# answer() takes a list of question dicts and returns one probability vector per question\n",
    "questions = [{'question': question, 'options': options}]\n",
    "probs = mqag_model.answer(questions=questions, context=context)\n",
    "print(probs[0])"
   ]
  }
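  ,
  {
   "cell_type": "markdown",
   "id": "b27c6e15",
   "metadata": {},
   "source": [
    "- The probabilities are in the same order as `options`, so the predicted answer is the option with the largest probability. The cell below is a minimal sketch that assumes `probs[0]` is a NumPy array, as its printed form above suggests. In the run above the largest probability (about 0.99) belongs to the third option, `Joao Felix`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c58a0d93",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# index of the most probable option for the first (and only) question\n",
    "best = int(np.argmax(probs[0]))\n",
    "print(f'{options[best]} (p={probs[0][best]:.4f})')"
   ]
  }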
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
